AllToAll implementation #5705
Conversation
Review updated until commit 62109c2

Description

Relevant files
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Tensor Layout Assumptions

!test
```cpp
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
// TODO(prmishra): Fix the layout for AllToAll.
```
Depending on where we relayout the input/output to be compliant with alltoall requirements, this function and potentially ReorderShardedAxisPass will be affected. I will do it in a follow-up PR once the current PR has been reviewed and we agree on the approach for reordering the input/output of alltoall.
Greptile Overview

Greptile Summary

This PR implements AllToAll collective communication support for nvfuser's multidevice framework. AllToAll enables simultaneous gathering along one dimension and scattering along another, effectively resharding tensors across devices.

Key Implementation Components

Communication Type Addition: Adds `CommunicationType::AllToAll` to the set of supported communication types.

Core Runtime Implementation: The `postAllToAll()` routine reshapes the input to split the scattered dimension, permutes `DIDx(d)` to the outermost position, flattens the tensors, and posts `alltoall_base` on the backend.

Detection Logic: AllToAll is detected when both producer and consumer tensors are sharded but on different logical dimensions (`c_logical_id != p2c_map.at(p_logical_id)`).

Layout Handling: Includes an explicit TODO comment indicating that layout support for AllToAll needs fixes. Currently, `getCommunicationLayout()` returns early for AllToAll.

Areas Requiring Attention
Confidence Score: 3/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
participant User as User Code
participant FD as FusionDefinition
participant Lower as lower_to_communication
participant Comm as Communication
participant NCCL as NCCL Backend
User->>FD: define AllToAll fusion with permute ops
User->>FD: set sharding on different dims
FD->>Lower: getCommunicationInfo(expr)
Lower->>Lower: detect p_sharded && c_sharded
Lower->>Lower: check if c_logical_id != p2c_map[p_logical_id]
Lower->>Lower: return AllToAll with nullptr sharded_ids
FD->>Lower: convertSingleOpToCommunication()
Lower->>Lower: getCommunicationLayout() returns early for AllToAll
Lower->>Lower: lowerToAllToAll()
Lower->>Comm: create Communication(AllToAll)
User->>FD: execute([input_tensor])
FD->>Comm: postSingleCommunication()
Comm->>Comm: postAllToAll()
Comm->>Comm: check isTvContiguous for input/output
Comm->>Comm: reshape input to split scattered dim
Comm->>Comm: permute to move DIDx(d) outermost
Comm->>Comm: flatten tensors with viewAsCompact
Comm->>NCCL: alltoall_base(output, input)
NCCL-->>Comm: Work handle
Comm-->>User: output tensor with resharded data
```
4 files reviewed, 4 comments
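The resharding flow shown in the sequence diagram can be reproduced with plain PyTorch to see the intended data movement. The sketch below is illustrative, not nvfuser code: the `reshard` helper, the shape names `n`, `m`, `d`, and the launch command are assumptions; `torch.distributed.all_to_all_single` (which splits along dim 0) is the only real API used.

```python
# Illustrative sketch of AllToAll resharding with torch.distributed; not
# nvfuser code. Each rank starts with rows of a global [n, m] tensor and
# ends with columns of it. Run with: torchrun --nproc_per_node=<d> <file>.py
import torch
import torch.distributed as dist


def reshard(inp: torch.Tensor, d: int) -> torch.Tensor:
    # inp: local shard of shape [n/d, m]; the global tensor is [n, m].
    n_over_d, m = inp.shape
    assert m % d == 0, "scattered dimension must be divisible by the team size"
    # Split the scattered dim and move the device dim outermost: [d, n/d, m/d].
    chunks = inp.view(n_over_d, d, m // d).permute(1, 0, 2).contiguous()
    out = torch.empty_like(chunks)
    # all_to_all_single splits along dim 0: chunks[j] goes to rank j, and the
    # chunk received from rank i lands in out[i].
    dist.all_to_all_single(out, chunks)
    # Merge the received chunks back into the gathered dim: [n, m/d].
    return out.reshape(d * n_over_d, m // d)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    d = dist.get_world_size()
    local = torch.randn(4, 2 * d, device="cuda")  # n/d = 4, m = 2*d
    print(dist.get_rank(), reshard(local, d).shape)  # -> [4*d, 2]
```

The reshape/permute/contiguous steps in this sketch roughly mirror the reshape, permute, and viewAsCompact steps in the diagram.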
```cpp
// input_tv = [DIDx(d), n/d, m, ...]
// output_tv = [n, DIDx(d), m/d, ...]
// `n`: gathered dimension
// `m`: scattered dimension
// For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
// m/d, ...] such that alltoall_base splits across the `d` dimension.
```
The comment describing the input/output tensor shapes appears to be inconsistent with the actual implementation below.
The comment states:
- `input_tv = [DIDx(d), n/d, m, ...]`
- `m`: scattered dimension (at position 2)

But the code at lines 644-651 operates on dimension 1 (`input_sizes.at(1)`), treating it as the scattered dimension. Based on the test case in test_alltoall, the actual runtime input tensor shape is `[n/d, m, k, ...]` (after permutation), where:
- Position 0: `n/d` (gathered dimension)
- Position 1: `m` (scattered dimension)
- Position 2+: other dimensions
The comment should be updated to reflect the actual runtime tensor layout rather than the logical sharding representation, or clarify that it's describing logical sharding rather than physical memory layout.
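To sanity-check the shape bookkeeping described above, here is a small single-process sketch (illustrative, not the PR's code) using the runtime layout `[n/d, m, k]` with assumed sizes:

```python
# Illustrative only: mimics the split/reorder described above on a plain CPU
# tensor, assuming the runtime input layout [n/d, m, k] with d = 4.
import torch

d, n_over_d, m, k = 4, 3, 8, 5
inp = torch.arange(n_over_d * m * k, dtype=torch.float32).reshape(n_over_d, m, k)

# Split the scattered dim m into [d, m/d]: [n/d, m, k] -> [n/d, d, m/d, k].
split = inp.view(n_over_d, d, m // d, k)

# Move the device dim outermost so that chunk j, destined for rank j, is a
# leading slice: [d, n/d, m/d, k].
reordered = split.permute(1, 0, 2, 3)
print(reordered.shape)            # torch.Size([4, 3, 2, 5])
print(reordered.is_contiguous())  # False: permute only swaps strides, so a
                                  # copy is needed before flattening
flat = reordered.contiguous().view(-1)
print(flat.numel() == inp.numel())  # True
```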
```cpp
if (c_logical_id == p2c_map.at(p_logical_id)) {
  fill_communication_info(
      CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
  fill_communication_info(
      CommunicationType::AllToAll, nullptr, nullptr);
}
```
Setting both p_sharded_id and c_sharded_id to nullptr for AllToAll communication may cause issues in other parts of the codebase that expect valid IterDomain pointers.
For example, getCommunicationLayout() (lines 476-521) is called with sharded_id as a parameter, and at line 493 it checks posInDomain(layout.allocation_domain(), sharded_id), which could fail if sharded_id is nullptr.
The current workaround is that AllToAll is explicitly handled in getCommunicationLayout() at line 489 to return early, but this creates tight coupling and a fragile dependency. Consider one of the following:
- Storing the actual producer and consumer sharded IDs for AllToAll (even though they map to different logical dimensions)
- Adding explicit nullptr checks in all functions that consume CommunicationInfo
- Documenting this assumption clearly in the CommunicationInfo struct definition
```diff
  // TODO(prmishra): Fix the layout for AllToAll.
  if (type == CommunicationType::Reduce ||
      type == CommunicationType::Allreduce ||
      type == CommunicationType::Broadcast ||
-     type == CommunicationType::SendRecv) {
+     type == CommunicationType::SendRecv ||
+     type == CommunicationType::AllToAll) {
```
The TODO comment indicates that the layout for AllToAll still needs to be fixed. This suggests the current implementation may not handle all cases correctly.
Based on the test comments in test_alltoall, the current approach uses permute operations to avoid non-adjacent splits that would fail stride validation (see the sketch after this list). However, this workaround has limitations:
- It requires explicit permutation operations in the fusion definition
- It may introduce extra copies when making tensors contiguous
- It doesn't expose the true allocation domain structure to the fusion

Before merging, clarify:
- What specific layout issues remain to be addressed?
- Are there known failure cases with the current implementation?
- Is this TODO blocking, or can it be addressed in a follow-up PR?
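As a concrete illustration of the stride-validation and extra-copy points (not nvfuser code), flattening a permuted tensor in PyTorch fails unless a contiguous copy is made first:

```python
# Illustrative sketch (not nvfuser code): flattening a permuted tensor fails
# stride validation, so an explicit contiguous copy is required.
import torch

t = torch.randn(4, 6)
p = t.permute(1, 0)           # only strides change; no data is moved
print(p.is_contiguous())      # False

try:
    p.view(-1)                # flattening a non-contiguous view raises
except RuntimeError as e:
    print("view failed:", e)

flat = p.contiguous().view(-1)  # works, but materializes an extra copy
print(flat.shape)               # torch.Size([24])
```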
```cpp
NVF_CHECK(
    input_sizes.at(1) % d == 0,
    "Scattered dimension must be divisible by the team size");
```
The error message says "Scattered dimension must be divisible by the team size" but this check only validates one dimension.
For a complete AllToAll operation, both the scattered dimension (at position 1 in the runtime tensor) and the gathered dimension (at position 0) should be validated:
- The scattered dimension `input_sizes.at(1)` must be divisible by `d` (currently checked ✓)
- The gathered dimension should result in an output where the gathered size equals `d * input_sizes.at(0)`
Consider adding validation for the output tensor dimensions as well to catch shape mismatches early, rather than relying on assertBuffersHaveSameSize at line 668, which only provides a generic error.
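A sketch of the suggested checks, written in Python for illustration rather than as the actual `NVF_CHECK` calls; the names `input_sizes`, `output_sizes`, and `d` mirror the review comment and are assumptions:

```python
# Illustrative Python version of the suggested checks; input_sizes,
# output_sizes, and d are assumed names mirroring the review comment.
def validate_alltoall_shapes(input_sizes, output_sizes, d):
    # Scattered dimension (position 1 of the runtime input) must split evenly
    # across the team.
    assert input_sizes[1] % d == 0, (
        f"Scattered dimension {input_sizes[1]} must be divisible by "
        f"the team size {d}")
    # Gathered dimension of the output should be d times the input's local
    # gathered extent (position 0).
    assert output_sizes[0] == d * input_sizes[0], (
        f"Expected gathered output size {d * input_sizes[0]}, "
        f"got {output_sizes[0]}")


# Example with d = 4 ranks: local input [n/d=3, m=8, k=5], output [n=12, m/d=2, k=5].
validate_alltoall_shapes([3, 8, 5], [12, 2, 5], d=4)
```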
No description provided.